{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MF7BncmmLBeO"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn import datasets\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as tt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**DISCLAIMER**\n",
    "\n",
    "The presented code is not optimized, it serves an educational purpose. It is written for CPU, it uses only fully-connected networks and an extremely simplistic dataset. However, it contains all components that can help to understand how an energy-based model (EBM) works, and it should be rather easy to extend it to more sophisticated models. This code could be run almost on any laptop/PC, and it takes a couple of minutes top to get the result."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RKsmjLumL5A2"
   },
   "source": [
    "## Dataset: Digits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we go wild and use a dataset that is simpler than MNIST! We use a scipy dataset called Digits. It consists of ~1500 images of size 8x8, and each pixel can take values in $\\{0, 1, \\ldots, 16\\}$.\n",
    "\n",
    "The goal of using this dataset is that everyone can run it on a laptop, without any gpu etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hSWUnXAYLLif"
   },
   "outputs": [],
   "source": [
    "class Digits(Dataset):\n",
    "    \"\"\"Scikit-Learn Digits dataset.\"\"\"\n",
    "\n",
    "    def __init__(self, mode='train', transforms=None):\n",
    "        digits = load_digits()\n",
    "        if mode == 'train':\n",
    "            self.data = digits.data[:1000].astype(np.float32)\n",
    "            self.targets = digits.target[:1000]\n",
    "        elif mode == 'val':\n",
    "            self.data = digits.data[1000:1350].astype(np.float32)\n",
    "            self.targets = digits.target[1000:1350]\n",
    "        else:\n",
    "            self.data = digits.data[1350:].astype(np.float32)\n",
    "            self.targets = digits.target[1350:]\n",
    "\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample_x = self.data[idx]\n",
    "        sample_y = self.targets[idx]\n",
    "        if self.transforms:\n",
    "            sample_x = self.transforms(sample_x)\n",
    "        return (sample_x, sample_y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qSP2qiMqMICK"
   },
   "source": [
    "## Energy-based Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GRYA6JA4LWEC"
   },
   "outputs": [],
   "source": [
    "class EBM(nn.Module):\n",
    "    def __init__(self, energy_net, alpha, sigma, ld_steps, D):\n",
    "        super(EBM, self).__init__()\n",
    "\n",
    "        print('EBM by JT.')\n",
    "\n",
    "        # the neural net used by the EBM\n",
    "        self.energy_net = energy_net\n",
    "\n",
    "        # the loss for classification\n",
    "        self.nll = nn.NLLLoss(reduction='none')  # it requires log-softmax as input!!\n",
    "\n",
    "        # hyperparams\n",
    "        self.D = D\n",
    "\n",
    "        self.sigma = sigma\n",
    "\n",
    "        self.alpha = torch.FloatTensor([alpha])\n",
    "\n",
    "        self.ld_steps = ld_steps\n",
    "\n",
    "    def classify(self, x):\n",
    "        f_xy = self.energy_net(x)\n",
    "        y_pred = torch.softmax(f_xy, 1)\n",
    "        return torch.argmax(y_pred, dim=1)\n",
    "\n",
    "    def class_loss(self, f_xy, y):\n",
    "        # - calculate logits (for classification)\n",
    "        y_pred = torch.softmax(f_xy, 1)\n",
    "\n",
    "        return self.nll(torch.log(y_pred), y)\n",
    "\n",
    "    def gen_loss(self, x, f_xy):\n",
    "        # - sample using Langevine dynamics\n",
    "        x_sample = self.sample(x=None, batch_size=x.shape[0])\n",
    "\n",
    "        # - calculate f(x_sample)[y]\n",
    "        f_x_sample_y = self.energy_net(x_sample)\n",
    "\n",
    "        return -(torch.logsumexp(f_xy, 1) - torch.logsumexp(f_x_sample_y, 1))\n",
    "\n",
    "    def forward(self, x, y, reduction='avg'):\n",
    "        # =====\n",
    "        # forward pass through the network\n",
    "        # - calculate f(x)[y]\n",
    "        f_xy = self.energy_net(x)\n",
    "\n",
    "        # =====\n",
    "        # discriminative part\n",
    "        # - calculate the discriminative loss: the cross-entropy\n",
    "        L_clf = self.class_loss(f_xy, y)\n",
    "\n",
    "        # =====\n",
    "        # generative part\n",
    "        # - calculate the generative loss: E(x) - E(x_sample)\n",
    "        L_gen = self.gen_loss(x, f_xy)\n",
    "\n",
    "        # =====\n",
    "        # Final objective\n",
    "        if reduction == 'sum':\n",
    "            loss = (L_clf + L_gen).sum()\n",
    "        else:\n",
    "            loss = (L_clf + L_gen).mean()\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def energy_gradient(self, x):\n",
    "        self.energy_net.eval()\n",
    "\n",
    "        # copy original data that doesn't require grads!\n",
    "        x_i = torch.FloatTensor(x.data)\n",
    "        x_i.requires_grad = True  # WE MUST ADD IT, otherwise autograd won't work\n",
    "\n",
    "        # calculate the gradient\n",
    "        x_i_grad = torch.autograd.grad(torch.logsumexp(self.energy_net(x_i), 1).sum(), [x_i], retain_graph=True)[0]\n",
    "\n",
    "        self.energy_net.train()\n",
    "\n",
    "        return x_i_grad\n",
    "\n",
    "    def langevine_dynamics_step(self, x_old, alpha):\n",
    "        # Calculate gradient wrt x_old\n",
    "        grad_energy = self.energy_gradient(x_old)\n",
    "        # Sample eta ~ Normal(0, alpha)\n",
    "        epsilon = torch.randn_like(grad_energy) * self.sigma\n",
    "\n",
    "        # New sample\n",
    "        x_new = x_old + alpha * grad_energy + epsilon\n",
    "\n",
    "        return x_new\n",
    "\n",
    "    def sample(self, batch_size=64, x=None):\n",
    "        # - 1) Sample from uniform\n",
    "        x_sample = 2. * torch.rand([batch_size, self.D]) - 1.\n",
    "\n",
    "        # - 2) run Langevine Dynamics\n",
    "        for i in range(self.ld_steps):\n",
    "            x_sample = self.langevine_dynamics_step(x_sample, alpha=self.alpha)\n",
    "\n",
    "        return x_sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vUoPkTmrMVnx"
   },
   "source": [
    "## Evaluation and Training functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JvwmRoi7MVto"
   },
   "source": [
    "**Evaluation step, sampling and curve plotting**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JHx4RIqDLZe9"
   },
   "outputs": [],
   "source": [
    "def evaluation(test_loader, name=None, model_best=None, epoch=None):\n",
    "    # EVALUATION\n",
    "    if model_best is None:\n",
    "        # load best performing model\n",
    "        model_best = torch.load(name + '.model')\n",
    "\n",
    "    model_best.eval()\n",
    "    loss = 0.\n",
    "    loss_error = 0.\n",
    "    loss_gen = 0.\n",
    "    N = 0.\n",
    "    for indx_batch, (test_batch, test_targets) in enumerate(test_loader):\n",
    "        # hybrid loss\n",
    "        loss_t = model_best.forward(test_batch, test_targets, reduction='sum')\n",
    "        loss = loss + loss_t.item()\n",
    "        # classification error\n",
    "        y_pred = model_best.classify(test_batch)\n",
    "        e = 1.*(y_pred == test_targets)\n",
    "        loss_error = loss_error + (1. - e).sum().item()\n",
    "        # generative nll\n",
    "        f_xy_test = model_best.energy_net(test_batch)\n",
    "        loss_gen = loss_gen + model_best.gen_loss(test_batch, f_xy_test).sum()\n",
    "        # the number of examples\n",
    "        N = N + test_batch.shape[0]\n",
    "    loss = loss / N\n",
    "    loss_error = loss_error / N\n",
    "    loss_gen = loss_gen / N\n",
    "\n",
    "    if epoch is None:\n",
    "        print(f'FINAL PERFORMANCE: nll={loss}, ce={loss_error}, gen_nll={loss_gen}')\n",
    "    else:\n",
    "        print(f'Epoch: {epoch}, val nll={loss}, val ce={loss_error}, val gen_nll={loss_gen}')\n",
    "\n",
    "    return loss, loss_error, loss_gen\n",
    "\n",
    "\n",
    "def samples_real(name, test_loader):\n",
    "    # REAL-------\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    x, _ = next(iter(test_loader))\n",
    "    x = x.detach().numpy()\n",
    "\n",
    "    fig, ax = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(ax.flatten()):\n",
    "        plottable_image = np.reshape(x[i], (8, 8))\n",
    "        ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.savefig(name+'_real_images.pdf', bbox_inches='tight')\n",
    "    plt.close()\n",
    "\n",
    "\n",
    "def samples_generated(name, data_loader, extra_name=''):\n",
    "    # GENERATIONS-------\n",
    "    model_best = torch.load(name + '.model')\n",
    "    model_best.eval()\n",
    "\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    x = model_best.sample(num_x * num_y)\n",
    "    x = x.detach().numpy()\n",
    "\n",
    "    fig, ax = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(ax.flatten()):\n",
    "        plottable_image = np.reshape(x[i], (8, 8))\n",
    "        ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.savefig(name + '_generated_images' + extra_name + '.pdf', bbox_inches='tight')\n",
    "    plt.close()\n",
    "\n",
    "\n",
    "def plot_curve(name, nll_val, file_name='_nll_val_curve.pdf', color='b-'):\n",
    "    plt.plot(np.arange(len(nll_val)), nll_val, color, linewidth='3')\n",
    "    plt.xlabel('epochs')\n",
    "    plt.ylabel('nll')\n",
    "    plt.savefig(name + file_name, bbox_inches='tight')\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "umU3VYKzMbDt"
   },
   "source": [
    "**Training step**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NxkUZ1xVLbm_"
   },
   "outputs": [],
   "source": [
    "def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader):\n",
    "    nll_val = []\n",
    "    gen_val = []\n",
    "    error_val = []\n",
    "    best_nll = 1000.\n",
    "    patience = 0\n",
    "\n",
    "    # Main loop\n",
    "    for e in range(num_epochs):\n",
    "        # TRAINING\n",
    "        model.train()\n",
    "        for indx_batch, (batch, targets) in enumerate(training_loader):\n",
    "\n",
    "            loss = model.forward(batch, targets)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward(retain_graph=True)\n",
    "            optimizer.step()\n",
    "\n",
    "        # Validation\n",
    "        loss_e, error_e, gen_e = evaluation(val_loader, model_best=model, epoch=e)\n",
    "        nll_val.append(loss_e)  # save for plotting\n",
    "        gen_val.append(gen_e)  # save for plotting\n",
    "        error_val.append(error_e)  # save for plotting\n",
    "\n",
    "        if e == 0:\n",
    "            print('saved!')\n",
    "            torch.save(model, name + '.model')\n",
    "            best_nll = loss_e\n",
    "        else:\n",
    "            if loss_e < best_nll:\n",
    "                print('saved!')\n",
    "                torch.save(model, name + '.model')\n",
    "                best_nll = loss_e\n",
    "                patience = 0\n",
    "\n",
    "                samples_generated(name, val_loader, extra_name=\"_epoch_\" + str(e))\n",
    "            else:\n",
    "                patience = patience + 1\n",
    "\n",
    "        if patience > max_patience:\n",
    "            break\n",
    "\n",
    "    nll_val = np.asarray(nll_val)\n",
    "    error_val = np.asarray(error_val)\n",
    "    gen_val = np.asarray(gen_val)\n",
    "\n",
    "    return nll_val, error_val, gen_val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0BXJ9dN0MinB"
   },
   "source": [
    "## Experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KsF7f-Q-MkWu"
   },
   "source": [
    "**Initialize datasets**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transforms_train = tt.Compose( [tt.Lambda(lambda x: 2. * (x / 17.) - 1.),\n",
    "                              tt.Lambda(lambda x: torch.from_numpy(x)),\n",
    "                              tt.Lambda(lambda x: x + 0.03 * torch.randn_like(x))\n",
    "                              ])\n",
    "\n",
    "transforms_val  = tt.Compose( [tt.Lambda(lambda x: 2. * (x / 17.) - 1.),\n",
    "                               tt.Lambda(lambda x: torch.from_numpy(x)),\n",
    "                               ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fqZKMNM0LdQ1"
   },
   "outputs": [],
   "source": [
    "train_data = Digits(mode='train', transforms=transforms_train)\n",
    "val_data = Digits(mode='val', transforms=transforms_val)\n",
    "test_data = Digits(mode='test', transforms=transforms_val)\n",
    "\n",
    "training_loader = DataLoader(train_data, batch_size=64, shuffle=True)\n",
    "val_loader = DataLoader(val_data, batch_size=64, shuffle=False)\n",
    "test_loader = DataLoader(test_data, batch_size=64, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6lEKUznpMns7"
   },
   "source": [
    "**Hyperparameters**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ANQo7LrGLjIN"
   },
   "outputs": [],
   "source": [
    "D = 64  # input dimension\n",
    "K = 10 # output dimension\n",
    "M = 512  # the number of neurons\n",
    "\n",
    "sigma = 0.01 # the noise level\n",
    "\n",
    "alpha = 1.  # the step-size for SGLD\n",
    "ld_steps = 20  # the number of steps of SGLD\n",
    "\n",
    "lr = 1e-3  # learning rate\n",
    "num_epochs = 1000  # max. number of epochs\n",
    "max_patience = 20  # an early stopping is used, if training doesn't improve for longer than 20 epochs, it is stopped"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-7APXeunMrDh"
   },
   "source": [
    "**Creating a folder for results**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bjSUn1eWLkWm"
   },
   "outputs": [],
   "source": [
    "name = 'ebm' + '_' + str(alpha) + '_' + str(sigma) + '_' + str(ld_steps)\n",
    "result_dir = 'results/' + name + '/'\n",
    "if not (os.path.exists(result_dir)):\n",
    "    os.mkdir(result_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Hpwm6LWUMulQ"
   },
   "source": [
    "**Initializing the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "FrnNsCqQLmK3",
    "outputId": "5f0cf2b1-0a96-4f5c-da9e-f78f909a5259"
   },
   "outputs": [],
   "source": [
    "energy_net = nn.Sequential(nn.Linear(D, M), nn.ELU(),\n",
    "                               nn.Linear(M, M), nn.ELU(),\n",
    "                               nn.Linear(M, M), nn.ELU(),\n",
    "                               nn.Linear(M, K))\n",
    "\n",
    "# We initialize the full model\n",
    "model = EBM(energy_net, alpha=alpha, sigma=sigma, ld_steps=ld_steps, D=D)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3SzTemY3NSxO"
   },
   "source": [
    "**Optimizer - here we use Adamax**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "R9TZtLVtLoWc"
   },
   "outputs": [],
   "source": [
    "# OPTIMIZER\n",
    "optimizer = torch.optim.Adamax([p for p in model.parameters() if p.requires_grad == True], lr=lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dNf__W_ONVHA"
   },
   "source": [
    "**Training loop**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KhqHgluGLqIC",
    "outputId": "c52fa1e4-3376-4bff-9f87-6f03613c4e42"
   },
   "outputs": [],
   "source": [
    "# Training procedure\n",
    "nll_val, error_val, gen_val = training(name=result_dir + name, max_patience=max_patience, num_epochs=num_epochs,\n",
    "                                       model=model, optimizer=optimizer,\n",
    "                                       training_loader=training_loader, val_loader=val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-3XTxgEcNXfp"
   },
   "source": [
    "**The final evaluation**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "okK1mV_-LrRU",
    "outputId": "4664693f-742d-4453-94cf-d051d2efa9be"
   },
   "outputs": [],
   "source": [
    "test_loss, test_error, test_gen = evaluation(name=result_dir + name, test_loader=test_loader)\n",
    "f = open(result_dir + name + '_test_loss.txt', \"w\")\n",
    "f.write('NLL: ' + str(test_loss) + '\\nCA: ' + str(test_error) + '\\nGEN NLL: ' + str(test_gen))\n",
    "f.close()\n",
    "\n",
    "samples_real(result_dir + name, test_loader)\n",
    "samples_generated(result_dir + name, test_loader)\n",
    "\n",
    "plot_curve(result_dir + name, nll_val)\n",
    "plot_curve(result_dir + name, error_val, file_name='_ca_val_curve.pdf', color='r-')\n",
    "plot_curve(result_dir + name, gen_val, file_name='_gen_val_curve.pdf', color='g-')"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "vae_priors.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
